package spark.pipeline.KDD99

import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
import org.apache.spark.ml.evaluation.ClusteringEvaluator
import org.apache.spark.ml.feature.{IndexToString, StandardScaler, StringIndexer, VectorAssembler}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.{SparkSession, functions}
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.DoubleType

import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable`

object Kdd99KMeans {
    // Row layouts kept for source compatibility with other callers; main() itself
    // works directly on DataFrames and does not use them.
    case class kdd99(label: String, features: Vector, protocol_type: String, service: String, flag: String)

    case class featuresAll(features: Vector)

    /** The 41 KDD99 feature column names, in the order they appear in the CSV (label excluded). */
    private val baseColumns: Array[String] = Array(
        "duration", "protocol_type", "service", "flag", "src_bytes", "dst_bytes", "land",
        "wrong_fragment", "urgent", "hot", "num_failed_logins", "logged_in", "num_compromised", "root_shell",
        "su_attempted", "num_root", "num_file_creations", "num_shells", "num_access_files", "num_outbound_cmds",
        "is_host_login", "is_guest_login", "count", "srv_count", "serror_rate", "srv_serror_rate", "rerror_rate",
        "srv_rerror_rate", "same_srv_rate", "diff_srv_rate", "srv_diff_host_rate", "dst_host_count",
        "dst_host_srv_count", "dst_host_same_srv_rate", "dst_host_diff_srv_rate", "dst_host_same_src_port_rate",
        "dst_host_srv_diff_host_rate", "dst_host_serror_rate", "dst_host_srv_serror_rate", "dst_host_rerror_rate",
        "dst_host_srv_rerror_rate")

    /** Training CSV layout: the 41 features followed by the attack label. */
    private val trainColumns: Array[String] = baseColumns :+ "label"

    /** VectorAssembler inputs: the numeric originals plus the indexed categorical columns. */
    private val assemblerColumns: Array[String] =
        baseColumns.filterNot(Set("protocol_type", "service", "flag")) ++
          Array("protocol_typeIndex", "serviceIndex", "flagIndex")

    /**
     * Trains a KMeans (k = 80) pipeline on the KDD99 10% training set, evaluates it on a
     * held-out split, then scores the unlabeled 10% test set and flags points whose squared
     * distance to their assigned centroid exceeds a threshold taken from the 50th-largest
     * test-split distance.
     *
     * Optional args (each falls back to the original hard-coded default when absent):
     *   args(0) - training CSV path
     *   args(1) - directory to save the fitted PipelineModel
     *   args(2) - unlabeled prediction CSV path
     */
    def main(args: Array[String]): Unit = {
        // Suppress Spark's INFO logging so only our own output and real errors show.
        Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
        val spark = SparkSession.builder().appName("check K of KMeans").master("local").getOrCreate()
        import spark.implicits._

        val file = args.lift(0).getOrElse("C:/Users/Lenovo/Desktop/Working/Python/data/kddcup.data_10_percent_corrected.csv")
        val saveFile = args.lift(1).getOrElse("file:///C:/Users/Lenovo/Desktop/Working/Python/data/Kdd99KMeans")
        val predictionFile = args.lift(2).getOrElse("C:/Users/Lenovo/Desktop/Working/Python/data/kddcup.testdata.unlabeled_10_percent.csv")

        val dataFrame = spark.read.csv(file).toDF(trainColumns: _*)
        dataFrame.cache()
        dataFrame.show(5, false)

        val timeOld = System.currentTimeMillis()

        // Encode each categorical column as a numeric index. The fitted indexers are reused
        // below to transform the unlabeled prediction set with the same label->index mapping.
        // NOTE(review): handleInvalid stays at its default ("error"); a category present in
        // the prediction file but absent from training makes transform() fail — confirm the
        // data, or consider setHandleInvalid("keep") if the Spark version supports it.
        val pretocol_indexer = new StringIndexer().setInputCol("protocol_type").setOutputCol("protocol_typeIndex").fit(dataFrame)
        val indexed_0 = pretocol_indexer.transform(dataFrame)

        val service_indexer = new StringIndexer().setInputCol("service").setOutputCol("serviceIndex").fit(indexed_0)
        val indexed_1 = service_indexer.transform(indexed_0)

        val flag_indexer = new StringIndexer().setInputCol("flag").setOutputCol("flagIndex").fit(indexed_1)
        val indexed_2 = flag_indexer.transform(indexed_1)

        val label_indexer = new StringIndexer().setInputCol("label").setOutputCol("labelIndex").fit(indexed_2)
        val indexed_df = label_indexer.transform(indexed_2)

        // Drop the original string columns now that their indexed versions exist.
        val df_final = indexed_df.drop("protocol_type", "service", "flag", "label")

        val assembler = new VectorAssembler()
          .setInputCols(assemblerColumns)
          .setOutputCol("features")

        // Everything left came out of the CSV read as strings; cast all of it to Double
        // before assembling the feature vector.
        val cols = df_final.columns.map(f => col(f).cast(DoubleType))
        val df_finalDataFrame = assembler.transform(df_final.select(cols: _*))

        val featuresDataFrame = df_finalDataFrame.select("labelIndex", "features")

        // Map labelIndex back to the original attack-name string for readable output.
        val labelBack = new IndexToString().setInputCol("labelIndex")
          .setOutputCol("label").setLabels(label_indexer.labels)

        val scaler = new StandardScaler()
          .setInputCol("features")
          .setOutputCol("scaledFeatureVector")
          .setWithStd(true)
          .setWithMean(false)

        // Collapse individual attack names into the four standard KDD99 categories
        // (DOS / R2L / U2R / Probing); "normal." is left untouched.
        val train = labelBack.transform(featuresDataFrame)
        val trainDF = train
          .na
          .replace("*", Map(
              "warezmaster." -> "R2L",
              "smurf." -> "DOS",
              "pod." -> "DOS",
              "imap." -> "R2L",
              "nmap." -> "Probing",
              "guess_passwd." -> "R2L",
              "ipsweep." -> "Probing",
              "portsweep." -> "Probing",
              "satan." -> "Probing",
              "land." -> "DOS",
              "loadmodule." -> "U2R",
              "ftp_write." -> "R2L",
              "buffer_overflow." -> "U2R",
              "rootkit." -> "U2R",
              "warezclient." -> "R2L",
              "teardrop." -> "DOS",
              "perl." -> "U2R",
              "phf." -> "R2L",
              "multihop." -> "R2L",
              "neptune." -> "DOS",
              "back." -> "DOS",
              "spy." -> "R2L"
          ))
        trainDF.show(5, false)

        // NOTE: full training takes roughly 3 hours on this machine — run with care.
        println("模型训练------------------------------------------------------------------------------------------------------>")
        val Array(trainData, testData) = featuresDataFrame.randomSplit(Array(0.7, 0.3))
        val km = new KMeans()
          .setFeaturesCol("scaledFeatureVector")
          .setPredictionCol("prediction")
          .setMaxIter(40)
          .setK(80)

        val pipeline = new Pipeline()
          .setStages(Array(scaler, km))
        val prKmeans = pipeline.fit(trainData)

        prKmeans.write.overwrite().save(saveFile)
        println("PipelineModel模型保存成功")
        // To reload later: PipelineModel.load(saveFile)

        val predictionKMeans = prKmeans.transform(testData)

        val evaluator = new ClusteringEvaluator()
          .setPredictionCol("prediction")
          .setFeaturesCol("scaledFeatureVector")
        println("聚类时间花费为:" + (System.currentTimeMillis() - timeOld) / 1000 + "s")

        println("模型评估------------------------------------------------------------------------------------------------------>")
        println("k = 80 欧氏几何距离:" + evaluator.evaluate(predictionKMeans))

        val kMeansModel = prKmeans.stages.last.asInstanceOf[KMeansModel]
        val centroids = kMeansModel.clusterCenters // the k cluster centers

        // Squared Euclidean distance from every test point to its assigned centroid,
        // ordered largest-first. (The Dataset[Double] column produced by map is named "value".)
        val distToCentroid = predictionKMeans.select("prediction", "scaledFeatureVector").as[(Int, Vector)]
          .map { case (cluster, vector) => Vectors.sqdist(centroids(cluster), vector) }
          .orderBy($"value".desc)

        val distances = distToCentroid.collect() // descending order preserved by collect
        val length: Long = distances.length.toLong

        // Drop the 10 most extreme points before computing the SSE.
        // BUG FIX: Vectors.sqdist already returns a *squared* distance, so the sum of squared
        // errors is the plain sum; the previous code squared each value a second time.
        println("删除top10数值")
        val sum = distances.drop(10).sum

        println("总误差为:" + sum + " 测试机数量为:" + (length) + " 误差平方和为:" + sum / (length - 10))

        // Anomaly threshold: the 50th-largest centroid distance on the test split
        // (or the largest available one if there are fewer than 50 points).
        val threshold = distances.take(50).last
        println("质心距离阈值为:" + threshold)

        println("分类排名top50------------------------------------------------------------------------------------------------->")
        val tes = predictionKMeans
          .groupBy("labelIndex", "prediction")
          .count()
          .orderBy(-col("count"))
        tes.show(50, false)

        // Rough purity estimate: for every cluster take its dominant label's count,
        // sum those, and divide by the total number of test points.
        tes.groupBy("prediction")
          .max("count")
          .withColumnRenamed("max(count)", "count")
          .agg(functions.sum("count"))
          .withColumnRenamed("sum(count)", "preCount")
          .withColumn("sumCount", lit(length))
          .withColumn("Accuracy", col("preCount") / col("sumCount"))
          .show(false)

        println("模型预测------------------------------------------------------------------------------------------------------>")

        val predictionDataFrame = spark.read.csv(predictionFile).toDF(baseColumns: _*)
        predictionDataFrame.cache()

        // Apply the indexers fitted on the training data.
        val predictionIndexed_0 = pretocol_indexer.transform(predictionDataFrame)
        val predictionindexed_1 = service_indexer.transform(predictionIndexed_0)
        val predictionIndexed_3 = flag_indexer.transform(predictionindexed_1)

        // Keep the original string columns aside so they can be shown next to the predictions.
        val proserfla = predictionIndexed_3.select("protocol_type", "service", "flag")
        val predictionDf_final = predictionIndexed_3.drop("protocol_type", "service", "flag")

        val predictionCols = predictionDf_final.columns.map(f => col(f).cast(DoubleType))
        val predictionDf_finalDataFrame = assembler.transform(predictionDf_final.select(predictionCols: _*))

        val preFeaturesDataFrame = predictionDf_finalDataFrame.select("features")

        val preDataFrame = prKmeans.transform(preFeaturesDataFrame)
        println("预测集合个数为:" + preDataFrame.count())

        // BUG FIX: the previous join had no join key and therefore built a cartesian product.
        // Pair the rows positionally with a generated id instead.
        // NOTE(review): monotonically_increasing_id is only positionally stable here because
        // both frames derive from the same single-partition local read — confirm this still
        // holds if the job is ever run with a partitioned input.
        val preWithId = preDataFrame.withColumn("rowId", functions.monotonically_increasing_id())
        val proserflaWithId = proserfla.withColumn("rowId", functions.monotonically_increasing_id())
        preWithId.join(proserflaWithId, "rowId").drop("rowId").show(20, false)

        println("异常检测------------------------------------------------------------------------------------------------------>")

        // A point is anomalous when its squared distance to its centroid exceeds the threshold.
        val abnormalDF = preDataFrame
          .filter(row => {
              val cluster = row.getAs[Int]("prediction")
              val vec = row.getAs[Vector]("scaledFeatureVector")
              Vectors.sqdist(centroids(cluster), vec) > threshold
          })
        abnormalDF.show(20, false)
    }

}
